package gu.sql2java.manager.parser;

import static gu.sql2java.SimpleLog.log;
import static com.google.common.base.Preconditions.checkNotNull;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.UncheckedExecutionException;

import gu.sql2java.exception.RuntimeDaoException;
import gu.sql2java.manager.SqlFormatter;
import gu.sql2java.manager.parser.ParserSupport.SqlParserInfo;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.parser.CCJSqlParserDefaultVisitor;
import net.sf.jsqlparser.parser.CCJSqlParserVisitor;
import net.sf.jsqlparser.parser.SimpleNode;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.util.TablesNamesFinder;

/**
 * Caches the result of parsing/normalizing SQL statements in a Guava {@link LoadingCache}
 * so that a statement is only parsed again after its cache entry has expired.
 * <p>
 * SQL texts that failed to parse, or that were rejected by the injection analysis, are
 * remembered together with the exception that was raised; a later {@link #parse(String, boolean)}
 * call with the same text rethrows the remembered exception without parsing again.
 *
 * @author guyadong
 */
public class StatementCache {
    /** Hours an entry may stay unused before it is dropped from any of the caches of this class. */
    private static final long CACHE_EXPIRE_HOURS = 1;

    /**
     * Visitor handed to {@link ParserSupport#parse0} together with the SQL text; may be {@code null}.
     */
    private final CCJSqlParserVisitor visitor;
    /** Normalizer handed to {@link ParserSupport#parse0} together with the SQL text; may be {@code null}. */
    private final SqlSyntaxNormalizer sqlSyntaxNormalizer;
    /**
     * Successfully parsed statements, keyed by the original SQL text.
     * <p>
     * The loader is intentionally an anonymous class rather than a lambda: it reads the
     * blank-final fields {@link #visitor} and {@link #sqlSyntaxNormalizer} from a field initializer.
     */
    private final LoadingCache<String, SqlParserInfo> statementCache = CacheBuilder.newBuilder()
            // entries expire after not being accessed for CACHE_EXPIRE_HOURS
            .expireAfterAccess(CACHE_EXPIRE_HOURS, TimeUnit.HOURS)
            .build(new CacheLoader<String, SqlParserInfo>() {
                @Override
                public SqlParserInfo load(String key) throws Exception {
                    return ParserSupport.parse0(key, visitor, sqlSyntaxNormalizer);
                }
            });
    /**
     * SQL texts already rejected by the injection analysis, mapped to the exception to rethrow.
     * Backed by an expiring cache so that rejected statements are not retained forever.
     */
    private final ConcurrentMap<String, RuntimeDaoException> dangerousSqls = newFailureCache();
    /**
     * SQL texts that failed to parse, mapped to the exception to rethrow.
     * Backed by an expiring cache so that invalid statements are not retained forever.
     */
    private final ConcurrentMap<String, RuntimeDaoException> invalidSqls = newFailureCache();
    /** Performs the injection analysis on parsed statements. */
    private final SqlInjectionAnalyzer injectAnalyzer;

    /**
     * Creates a cache without visitor and without syntax normalizer.
     */
    public StatementCache() {
        this((CCJSqlParserDefaultVisitor) null, null);
    }

    /**
     * Creates a cache that passes the given visitor and normalizer to the parser.
     *
     * @param vistor visitor passed to the parser, may be {@code null}
     * @param sqlSyntaxNormalizer normalizer passed to the parser, may be {@code null}
     */
    public StatementCache(CCJSqlParserDefaultVisitor vistor, SqlSyntaxNormalizer sqlSyntaxNormalizer) {
        this.visitor = vistor;
        this.injectAnalyzer = new SqlInjectionAnalyzer();
        this.sqlSyntaxNormalizer = sqlSyntaxNormalizer;
    }

    /**
     * Creates a cache whose parser visitor forwards every AST node to {@code vistor}.
     *
     * @param vistor delegate that receives every visited node, may be {@code null}
     * @param sqlSyntaxNormalizer normalizer passed to the parser, may be {@code null}
     */
    public StatementCache(CCJSqlParserVisitor vistor, SqlSyntaxNormalizer sqlSyntaxNormalizer) {
        this(new AstNodeVisitor(vistor), sqlSyntaxNormalizer);
    }

    /**
     * Creates a cache whose parser visitor rewrites column, table and alias names
     * with the given formatter.
     *
     * @param sqlFormatter formatter applied to names, may be {@code null}
     * @param sqlSyntaxNormalizer normalizer passed to the parser, may be {@code null}
     */
    public StatementCache(SqlFormatter sqlFormatter, SqlSyntaxNormalizer sqlSyntaxNormalizer) {
        this(new AstNodeVisitor(sqlFormatter), sqlSyntaxNormalizer);
    }

    /**
     * Creates an expiring map used to remember SQL texts together with the exception they caused.
     * Uses the same expire-after-access period as {@link #statementCache}.
     */
    private static ConcurrentMap<String, RuntimeDaoException> newFailureCache() {
        return CacheBuilder.newBuilder()
                .expireAfterAccess(CACHE_EXPIRE_HOURS, TimeUnit.HOURS)
                .<String, RuntimeDaoException>build()
                .asMap();
    }

    /**
     * Forwards the flag to the underlying {@link SqlInjectionAnalyzer}.
     *
     * @param enable value passed to {@link SqlInjectionAnalyzer#injectCheckEnable(boolean)}
     * @return this instance, for call chaining
     */
    public StatementCache injectCheckEnable(boolean enable) {
        injectAnalyzer.injectCheckEnable(enable);
        return this;
    }

    /**
     * Parses a SQL statement and returns the {@link SqlParserInfo} holding the parse result.
     * <p>
     * A statement that was previously recorded as dangerous or invalid is rejected immediately
     * with the recorded exception, regardless of {@code injectAnalyze}.
     *
     * @param sql SQL statement, must not be {@code null}
     * @param injectAnalyze run the injection attack analysis on the parsed statement if {@code true}
     * @return the parse result
     * @throws NullPointerException if {@code sql} is {@code null}
     * @throws RuntimeDaoException wrapping the cause of a loading failure
     *         (e.g. {@link net.sf.jsqlparser.JSQLParserException}) or the {@link InjectionAttackException}
     */
    public SqlParserInfo parse(String sql, boolean injectAnalyze) {
        // Fail early with a clear message; the caches below reject null keys anyway.
        checkNotNull(sql, "sql is null");
        RuntimeDaoException recorded = dangerousSqls.get(sql);
        if (null == recorded) {
            recorded = invalidSqls.get(sql);
        }
        if (null != recorded) {
            throw recorded;
        }
        try {
            SqlParserInfo sqlParserInfo = statementCache.get(sql);
            return injectAnalyze ? injectAnalyzer.injectAnalyse(sqlParserInfo) : sqlParserInfo;
        } catch (ExecutionException | UncheckedExecutionException e) {
            // the loader failed: remember the statement as invalid
            RuntimeDaoException rde = new RuntimeDaoException(e.getCause());
            invalidSqls.put(sql, rde);
            throw rde;
        } catch (InjectionAttackException e) {
            // the analyzer rejected the statement: remember it as dangerous
            RuntimeDaoException rde = new RuntimeDaoException(e);
            dangerousSqls.put(sql, rde);
            throw rde;
        }
    }

    /**
     * Parses a SQL statement and returns the {@code nativeSql} of the resulting {@link SqlParserInfo}.
     *
     * @param sql SQL statement, must not be {@code null}
     * @param injectAnalyze run the injection attack analysis if {@code true}
     * @return the SQL text produced by the parser
     * @throws RuntimeDaoException see {@link #parse(String, boolean)}
     */
    public String normalize(String sql, boolean injectAnalyze) {
        return parse(sql, injectAnalyze).nativeSql;
    }

    /**
     * Creates a {@link PreparedStatement} with {@link Connection#prepareStatement(String, int, int)}
     * after passing the SQL statement through {@link #normalize(String, boolean)}.
     *
     * @param c SQL connection
     * @param sql sql statement
     * @param injectAnalyze run injection attack analysis if true
     * @param debug output SQL statement to console if true
     * @param logPrefix prefix string for debug information
     * @param resultSetType see also {@link Connection#prepareStatement(String, int, int)}
     * @param resultSetConcurrency see also {@link Connection#prepareStatement(String, int, int)}
     * @throws SQLException
     * @throws NullPointerException if {@code sql} or {@code c} is {@code null}
     * @see Connection#prepareStatement(String, int, int)
     */
    public PreparedStatement prepareStatement(Connection c, String sql, boolean injectAnalyze, boolean debug,
            String logPrefix,
            int resultSetType, int resultSetConcurrency) throws SQLException {
        sql = normalize(checkNotNull(sql, "sql is null"), injectAnalyze);
        if (debug) {
            // pass the SQL as an argument so its text is never interpreted as a format string
            log("{} : {}", logPrefix, sql);
        }
        return checkNotNull(c, "connection is null").prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    /**
     * Creates a {@link PreparedStatement} with {@link Connection#prepareStatement(String, int)}
     * after passing the SQL statement through {@link #normalize(String, boolean)}.
     *
     * @param c SQL connection
     * @param sql sql statement
     * @param injectAnalyze run injection attack analysis if true
     * @param debug output SQL statement to console if true
     * @param logPrefix prefix string for debug information
     * @param autoGeneratedKeys see also {@link Connection#prepareStatement(String, int)}
     * @throws SQLException
     * @throws NullPointerException if {@code sql} or {@code c} is {@code null}
     * @see Connection#prepareStatement(String, int)
     */
    public PreparedStatement prepareStatement(Connection c, String sql, boolean injectAnalyze, boolean debug,
            String logPrefix, int autoGeneratedKeys) throws SQLException {
        sql = normalize(checkNotNull(sql, "sql is null"), injectAnalyze);
        if (debug) {
            log("{} : {}", logPrefix, sql);
        }
        return checkNotNull(c, "connection is null").prepareStatement(sql, autoGeneratedKeys);
    }

    /**
     * Creates a {@link PreparedStatement} with {@link Connection#prepareStatement(String)}
     * after passing the SQL statement through {@link #normalize(String, boolean)}.
     *
     * @param c SQL connection
     * @param sql sql statement
     * @param injectAnalyze run injection attack analysis if true
     * @param debug output SQL statement to console if true
     * @param logPrefix prefix string for debug information
     * @throws SQLException
     * @throws NullPointerException if {@code sql} or {@code c} is {@code null}
     * @see Connection#prepareStatement(String)
     */
    public PreparedStatement prepareStatement(Connection c, String sql, boolean injectAnalyze, boolean debug, String logPrefix)
            throws SQLException {
        sql = normalize(checkNotNull(sql, "sql is null"), injectAnalyze);
        if (debug) {
            // pass the SQL as an argument so its text is never interpreted as a format string
            log("{} : {}", logPrefix, sql);
        }
        return checkNotNull(c, "connection is null").prepareStatement(sql);
    }

    /**
     * {@link CCJSqlParserDefaultVisitor} that forwards every visited AST node to an optional
     * delegate before continuing with the inherited handling of the node.
     *
     * @author guyadong
     */
    private static class AstNodeVisitor extends CCJSqlParserDefaultVisitor {
        /** Delegate receiving every visited node; {@code null} disables forwarding. */
        private final CCJSqlParserVisitor visitor;

        AstNodeVisitor(CCJSqlParserVisitor visitor) {
            this.visitor = visitor;
        }

        /**
         * Adapts a {@link NodeVisitor}: the value attached to each node is dispatched to the
         * matching {@code visit} overload. The order of the checks matters, the first
         * matching type wins.
         */
        AstNodeVisitor(NodeVisitor finder) {
            this(null == finder ? null : (node, data) -> {
                Object value = node.jjtGetValue();
                if (value instanceof Column) {
                    finder.visit((Column) value);
                } else if (value instanceof Table) {
                    finder.visit((Table) value);
                } else if (value instanceof SelectExpressionItem) {
                    finder.visit((SelectExpressionItem) value);
                } else if (value instanceof FromItem) {
                    finder.visit((FromItem) value);
                }
                return data;
            });
        }

        AstNodeVisitor(SqlFormatter sqlFormatter) {
            this(null == sqlFormatter ? null : new NodeVisitor(sqlFormatter));
        }

        @Override
        public Object visit(SimpleNode node, Object data) {
            if (null != visitor) {
                // the delegate's return value is deliberately ignored
                visitor.visit(node, data);
            }
            return super.visit(node, data);
        }
    }

    /**
     * {@link TablesNamesFinder} that rewrites column names, table names and aliases
     * with a {@link SqlFormatter} while visiting them.
     *
     * @author guyadong
     */
    private static class NodeVisitor extends TablesNamesFinder {
        /** Formatter applied to the visited names; nothing is rewritten when {@code null}. */
        private final SqlFormatter sqlFormatter;

        NodeVisitor(SqlFormatter sqlFormatter) {
            this.sqlFormatter = sqlFormatter;
            // NOTE(review): presumably prepares the inherited finder state with column
            // processing allowed - confirm against the JSqlParser version in use
            init(true);
        }

        /** Replaces the alias name with its formatted form; a {@code null} alias is ignored. */
        private void visit(Alias alias) {
            if (null != sqlFormatter && null != alias) {
                alias.setName(sqlFormatter.alias(alias.getName()));
            }
        }

        /** Formats the alias of a generic FROM item; a {@code null} item is ignored. */
        void visit(FromItem fromItem) {
            if (null != sqlFormatter && null != fromItem) {
                visit(fromItem.getAlias());
            }
        }

        @Override
        public void visit(Column column) {
            // a column for which ParserSupport.isBoolean(column) holds is left untouched
            if (null != sqlFormatter && !ParserSupport.isBoolean(column)) {
                column.setColumnName(sqlFormatter.columname(column.getColumnName()));
            }
            super.visit(column);
        }

        @Override
        public void visit(SelectExpressionItem item) {
            if (null != sqlFormatter) {
                visit(item.getAlias());
            }
            super.visit(item);
        }

        @Override
        public void visit(Table table) {
            if (null != sqlFormatter) {
                // NOTE(review): the table name is formatted with columname(); verify that
                // SqlFormatter has no dedicated method for table names
                table.setName(sqlFormatter.columname(table.getName()));
                visit(table.getAlias());
            }
            super.visit(table);
        }
    }
}
